# Copyright 2025 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Generate data for test."""
import numpy as np


def get_init_params(seq_len=2, batch_size=2, hidden_size=32):
    """Return deterministic initialization parameters for the tests.

    Args:
        seq_len: Size of the first axis of ``inputs``.
        batch_size: Size of the second axis of ``inputs``.
        hidden_size: Size of the last axis of ``inputs``.

    Returns:
        dict: A single entry ``"inputs"`` holding a float64 array of shape
        ``(seq_len, batch_size, hidden_size)`` drawn from a standard normal
        distribution and scaled by 0.01.
    """
    # Reseed NumPy's *global* RNG (rather than using a private generator) so
    # the draw is reproducible; this also resets the global random state seen
    # by any later np.random calls the caller makes.
    np.random.seed(1)
    shape = (seq_len, batch_size, hidden_size)
    scale = 0.01
    return dict(inputs=scale * np.random.randn(*shape))


def get_golden_datas() -> dict[str, np.ndarray]:
    """Generate golden data for test.

    Returns a dict with keys ``"LayerNorm"`` and ``"RMSNorm"``, each mapped to
    a freshly allocated float64 array of shape ``(2, 2, 32)``.

    NOTE(review): these are presumably the expected outputs of the two norm
    layers applied to ``get_init_params()`` with its default sizes (whose
    inputs also have shape (2, 2, 32)) -- TODO confirm against the test that
    consumes them.
    """
    # Each name below is one 32-element row; the suffix gives its position in
    # the first two axes of the final array, e.g. ln_10 -> layer_norm[1][0].
    ln_00 = [1.6576998234, -0.5042570829, -0.4234439433, -0.9501764178,
             0.9239270091, -2.1380095482, 1.7741717100, -0.6487520933,
             0.3956750035, -0.1538870931, 1.5008417368, -1.9046156406,
             -0.2245118022, -0.2841052115, 1.1833903790, -0.9762063622,
             -0.0794961825, -0.7615356445, 0.1280286461, 0.6507047415,
             -0.9769100547, 1.1939814091, 0.9589103460, 0.5730472207,
             0.9581998587, -0.5738421082, -0.0316007882, -0.8175264597,
             -0.1717907786, 0.5999845266, -0.5815119743, -0.2963832915]
    ln_01 = [-0.9155051708, -1.0929702520, -0.8976202011, -0.1580575854,
             -1.3985339403, 0.1194044501, 1.7200605869, 0.6894524097,
             -0.3592599332, -1.1406100988, -0.9828668237, 1.7567279339,
             -0.0867804810, -0.8591581583, 0.0705553070, 2.2146730423,
             -0.0089016603, 0.5492604971, 0.1932444423, -0.5393992662,
             -1.4268413782, -0.5361346602, -0.3784161806, 0.5149204135,
             0.7983115911, 0.9017573595, 0.1768682748, 0.8501450419,
             -0.9909966588, 1.2630887032, 0.4321654737, -0.4785828888]
    ln_10 = [0.3653440773, -0.2589111924, 1.0770491362, 1.5066404343,
             2.2434084415, -1.7207247019, -1.7734210491, -0.7335509062,
             0.0018274576, 0.7943414450, 0.1740213037, -2.4131667614,
             -0.5141425729, 0.7410067916, 0.0793574303, 0.6680076718,
             -0.4213205576, -0.3974498510, 0.0311808344, 0.2785084248,
             0.0441711694, -0.0435770340, -0.9174737930, 0.2425554097,
             -0.0404644236, 1.0746748447, 1.1515146494, 0.0296260100,
             -0.5905916691, -0.8821360469, 0.2933849096, -0.0896899551]
    ln_11 = [-0.3989730179, -0.0128486855, -0.6741749644, 0.6393464208,
             -0.5018944144, 1.1640199423, 0.3458141983, 0.5352504849,
             -1.1474603415, 0.1125063300, 0.6817252636, -1.0067324638,
             -0.3216035664, -0.0237934031, -1.4247134924, 0.2577843070,
             0.7869680524, -0.9128701687, 0.2930497825, -1.3640878201,
             -0.0948593691, -1.6665381193, 1.0612828732, 0.3512045741,
             -0.0808290094, -0.8288046122, 1.2130995989, 1.9040720463,
             -1.9079184532, 1.1756364107, 1.5657829046, 0.2805584073]

    rms_00 = [1.5645461082, -0.5892349482, -0.5087274313, -1.0334678888,
              0.8335481882, -2.2168090343, 1.6805775166, -0.7331835032,
              0.3072938621, -0.2401899546, 1.4082812071, -1.9842977524,
              -0.3105475903, -0.3699156046, 1.0920304060, -1.0593994856,
              -0.1660803556, -0.8455405831, 0.0406596735, 0.5613591671,
              -1.0601004362, 1.1025813818, 0.8683992028, 0.4839953184,
              0.8676914573, -0.6585568190, -0.1183660924, -0.9013196230,
              -0.2580259144, 0.5108307600, -0.6661976576, -0.3821472824]
    rms_01 = [-0.7638087273, -0.9394659996, -0.7461059690, -0.0140770013,
              -1.2419168949, 0.2605586052, 1.8449094296, 0.8247996569,
              -0.2132297456, -0.9866205454, -0.8304840922, 1.8812032938,
              0.0564740226, -0.7080357075, 0.2122070789, 2.3344833851,
              0.1335595101, 0.6860358715, 0.3336464167, -0.3915340602,
              -1.2699360847, -0.3883027136, -0.2321908772, 0.6520455480,
              0.9325499535, 1.0349419117, 0.3174370527, 0.9838553667,
              -0.8385311365, 1.3925925493, 0.5701336265, -0.3313372433]
    rms_10 = [0.5325049758, -0.0823762938, 1.2335228920, 1.6566632986,
              2.3823678493, -1.5222388506, -1.5741438866, -0.5498886704,
              0.1744470447, 0.9550604224, 0.3440551758, -2.2042829990,
              -0.3337750435, 0.9025266767, 0.2508127987, 0.8306237459,
              -0.2423468828, -0.2188346088, 0.2033596337, 0.4469732940,
              0.2161549032, 0.1297243536, -0.7310496569, 0.4115601778,
              0.1327902228, 1.2311842442, 1.3068702221, 0.2018281668,
              -0.4090761542, -0.6962426305, 0.4616263807, 0.0843038782]
    rms_11 = [-0.3421349227, 0.0433789380, -0.6169018149, 0.6945429444,
              -0.4448935986, 1.2183870077, 0.4014748037, 0.5906115770,
              -1.0894389153, 0.1685357839, 0.7368547916, -0.9489335418,
              -0.2648878396, 0.0324515253, -1.3662538528, 0.3135840893,
              0.8419311643, -0.8552196622, 0.3487937748, -1.3057240248,
              -0.0385020897, -1.6076960564, 1.1158123016, 0.4068566561,
              -0.0244939104, -0.7712869644, 1.2673890591, 1.9572691917,
              -1.8486948013, 1.2299851179, 1.6195149422, 0.3363221586]

    # Assemble the rows into (2, 2, 32) arrays; key order is preserved.
    return {
        "LayerNorm": np.array([[ln_00, ln_01], [ln_10, ln_11]]),
        "RMSNorm": np.array([[rms_00, rms_01], [rms_10, rms_11]]),
    }


def get_gpu_datas() -> dict[str, np.ndarray]:
    """Generate gpu data for test.

    Returns a dict with keys ``"LayerNorm"`` and ``"RMSNorm"``, each mapped to
    a float64 array of shape ``(2, 2, 32)``.

    The GPU reference values are identical, element for element, to the golden
    reference values. They are therefore taken from ``get_golden_datas()``
    rather than kept as a second hand-maintained copy that could silently
    drift. Each call still returns freshly allocated arrays, so mutating the
    result does not affect other callers or ``get_golden_datas()`` itself.

    If the GPU reference ever needs to differ from the golden data, replace
    this call with explicit arrays for the differing entries.
    """
    return get_golden_datas()


# Module-level reference tables, built once when this module is imported.
# Golden data is evaluated first, then GPU data, matching the original order.
GOLDEN_DATA, GPU_DATA = get_golden_datas(), get_gpu_datas()
